# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor"""
from __future__ import absolute_import

import numpy as np

from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore import _checkparam as Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Enumerates types of Layer
# Integer codes appended to the layer-type list by find_net_layertype_recur and
# compared against by the optimizers to choose the per-parameter update path.
Other = -1
Conv = 1
FC = 2
Embedding = 3
LayerNorm = 4
BatchNorm = 5

# AddN primitive used by _tensor_apply_decay to sum the decay term and the gradient.
op_add = P.AddN()
# Multitype function graphs; their typed overloads are registered right below.
apply_decay = C.MultitypeFuncGraph("apply_decay")
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Return ``gradient`` plus ``weight * weight_decay`` when ``if_apply`` is set, else ``gradient`` unchanged."""
    if not if_apply:
        return gradient
    decay_term = weight * weight_decay
    return op_add((decay_term, gradient))


@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """Run the momentum primitive ``opt`` on one weight and return True tied to that update via ``F.depend``."""
    update = opt(weight, moment, learning_rate, gradient, momentum)
    return F.depend(True, update)


# Gradient clipping configuration read by clip_gradient().
# When True, clip_gradient uses C.clip_by_global_norm instead of per-tensor clipping.
IS_ENABLE_GLOBAL_NORM = False
# Per-tensor clip mode passed to _clip_grad: 0 clips by value, 1 clips by norm.
GRADIENT_CLIP_TYPE = 1
# Clip threshold passed to both clipping paths.
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
# Shared HyperMap used by clip_gradient to apply clip_grad to every gradient.
hyper_map_op = C.HyperMap()


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a single gradient tensor.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. Any other value leaves `grad` untouched.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): The gradient to clip.

    Outputs:
        Tensor, the clipped gradient.
    """
    if clip_type not in [0, 1]:
        return grad
    grad_dtype = F.dtype(grad)
    if clip_type == 0:
        # Clip element-wise into [-clip_value, clip_value], with bounds cast to the gradient dtype.
        lower_bound = F.cast(F.tuple_to_array((-clip_value,)), grad_dtype)
        upper_bound = F.cast(F.tuple_to_array((clip_value,)), grad_dtype)
        return ops.clip_by_value(grad, lower_bound, upper_bound)
    norm_limit = F.cast(F.tuple_to_array((clip_value,)), grad_dtype)
    return nn.ClipByNorm()(grad, norm_limit)


def clip_gradient(enable_clip_grad, gradients):
    """
    Clip `gradients` when `enable_clip_grad` is truthy, otherwise return them unchanged.

    Global-norm clipping is used if IS_ENABLE_GLOBAL_NORM is set; otherwise every gradient is clipped
    individually through `clip_grad` with GRADIENT_CLIP_TYPE and GRADIENT_CLIP_VALUE.
    """
    if not enable_clip_grad:
        return gradients
    if IS_ENABLE_GLOBAL_NORM:
        return C.clip_by_global_norm(gradients, GRADIENT_CLIP_VALUE, None)
    per_tensor_clip = F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
    return hyper_map_op(per_tensor_clip, gradients)


# Block edge length used by caculate_device_shape when building the 4-D device shape.
C0 = 16


def _check_param(momentum, frequency, lr, cls_name):
    """
    Validate the hyper-parameters shared by the THOR optimizers.

    `momentum` must be a float that is at least 0.0, `frequency` must be an int that is at least 2,
    and `lr` must be a Tensor. The type checks are delegated to `Validator.check_value_type`;
    the range violations raise ValueError.
    """
    Validator.check_value_type("momentum", momentum, [float], cls_name)
    momentum_is_negative = isinstance(momentum, float) and momentum < 0.0
    if momentum_is_negative:
        momentum_msg = ("For 'thor', the argument 'momentum' must be at least 0.0, "
                        "but got 'momentum' {}.").format(momentum)
        raise ValueError(momentum_msg)

    Validator.check_value_type("frequency", frequency, [int], cls_name)
    frequency_too_small = isinstance(frequency, int) and frequency < 2
    if frequency_too_small:
        frequency_msg = ("For 'thor', the argument 'frequency' must be at least 2, "
                         "but got 'frequency' {}.").format(frequency)
        raise ValueError(frequency_msg)

    Validator.check_value_type("learning rate", lr, [Tensor], cls_name)


def caculate_device_shape(matrix_dim, channel, is_a):
    """
    Return ``((n, n, C0, C0), dim)`` for a square matrix, where ``n = dim // C0``.

    For an A matrix (`is_a` truthy) with ``channel // C0 == 0``, `matrix_dim` is first rescaled to
    ``(matrix_dim / channel) * C0`` before the shape is derived.
    """
    if is_a and channel // C0 == 0:
        matrix_dim = (matrix_dim / channel) * C0
    block_count = int(matrix_dim // C0)
    return (block_count, block_count, C0, C0), int(matrix_dim)


def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
    """Tell whether the (matrix_g_shape, matrix_a_shape) pair is one of the supported conv matmul shapes."""
    # Each entry is (g_blocks, a_blocks); the full shapes are (n, n, 16, 16).
    block_pairs = ((4, 49), (4, 4), (4, 36), (16, 4), (4, 16),
                   (8, 16), (8, 72), (32, 8), (32, 16), (8, 32),
                   (16, 32), (16, 144), (64, 16), (64, 32), (16, 64),
                   (32, 64), (32, 288), (128, 32), (128, 64), (32, 128))
    supported = [((g_blocks, g_blocks, 16, 16), (a_blocks, a_blocks, 16, 16))
                 for g_blocks, a_blocks in block_pairs]
    return (matrix_g_shape, matrix_a_shape) in supported


def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
    """
    Compute the blocked shapes of matrix A and matrix G for a given `split_dim`.

    Returns a pair ``(matrix_a_shape, matrix_g_shape)`` where ``matrix_a_shape`` is
    ``(batch_h, batch_w, edge_a, edge_a)`` and ``matrix_g_shape`` is ``(batch_h, edge_g, edge_g)``.
    """
    def _split(dim):
        # Returns (number of blocks, block edge) for one dimension.
        if dim % split_dim == 0:
            return dim // split_dim, split_dim
        if dim < split_dim:
            # The whole dimension fits in a single, smaller block.
            return 1, dim
        # One extra block for the remainder.
        return dim // split_dim + 1, split_dim

    batch_w, edge_a = _split(matrix_a_dim)
    batch_h, edge_g = _split(matrix_g_dim)
    return (batch_h, batch_w, edge_a, edge_a), (batch_h, edge_g, edge_g)


def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
    """
    Append `Other` to `layertype_map` for a trainable dense/conv cell.

    Cells whose weight does not require grad are ignored. Cells whose prefix contains
    "rpn_with_loss.rpn_convs_list." are also ignored, except those under "rpn_with_loss.rpn_convs_list.0.".
    """
    if not subcell.weight.requires_grad:
        return
    lowered_prefix = prefix.lower()
    outside_rpn_convs = "rpn_with_loss.rpn_convs_list." not in lowered_prefix
    if outside_rpn_convs or "rpn_with_loss.rpn_convs_list.0." in lowered_prefix:
        layertype_map.append(Other)


def find_net_layertype_recur(net, layertype_map):
    """
    Walk the cells of `net` and append one layer-type code per recognised layer to `layertype_map`.

    THOR layers map to Conv / FC / Embedding, nn.LayerNorm to LayerNorm, trainable nn.BatchNorm2d to
    BatchNorm, the remaining listed built-in layers to Other, and any other cell is descended into.
    """
    children = net.name_cells()
    for child_name in children:
        child = children[child_name]
        prefix = child.param_prefix
        if child == net:
            continue
        if isinstance(child, Conv2dThor):
            layertype_map.append(Conv)
            continue
        if isinstance(child, DenseThor):
            layertype_map.append(FC)
            continue
        if isinstance(child, (EmbeddingThor, EmbeddingLookupThor)):
            layertype_map.append(Embedding)
            continue
        if isinstance(child, nn.LayerNorm):
            layertype_map.append(LayerNorm)
            continue
        if isinstance(child, nn.BatchNorm2d):
            # A BatchNorm2d whose gamma is frozen contributes no entry at all.
            if child.gamma.requires_grad:
                layertype_map.append(BatchNorm)
            continue
        if isinstance(child, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d,
                              nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm)):
            if isinstance(child, (nn.Dense, nn.Conv2d)):
                get_layer_type_for_dense_and_conv(child, prefix, layertype_map)
            else:
                layertype_map.append(Other)
            continue
        # Container or unknown cell: recurse into it.
        find_net_layertype_recur(child, layertype_map)


def get_net_layertype_mask(net):
    """Return the list of layer-type codes that find_net_layertype_recur collects for `net`."""
    layer_types = []
    find_net_layertype_recur(net, layer_types)
    return layer_types


def get_layer_counter(layer_type, layer_counter, params, idx):
    """
    Return `layer_counter`, advanced by one when ``params[idx]`` is the last parameter of its layer.

    The decision is made from the lower-cased parameter names: a "bias" closes a layer, a weight closes
    it only if the following parameter is not a bias, and for LayerNorm/BatchNorm only "beta" closes it.
    """
    current_name = params[idx].name.lower()

    def _next_is_not_bias():
        # Only evaluated on the branches that need the look-ahead.
        return idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower()

    if layer_type in [Conv, FC]:
        layer_done = "bias" in current_name or _next_is_not_bias()
    elif layer_type in [LayerNorm, BatchNorm]:
        layer_done = "beta" in current_name
    elif "bias" in current_name:
        layer_done = True
    elif "weight" in current_name:
        layer_done = _next_is_not_bias()
    else:
        layer_done = True

    if layer_done:
        return layer_counter + 1
    return layer_counter


def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
         use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
         frequency=100):
    r"""
    Updates gradients by second-order algorithm--THOR.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll}
          & \textbf{Parameter:} \: \text{the learning rate } \gamma\text{, the damping parameter }\lambda \\
          & \textbf{Init:} \: \lambda \leftarrow 0 \\
          & A_{i-1}=\mathbb{E}\left[a_{i-1} a_{i-1}^{T}\right] \\
          & G_{i}=\mathbb{E}\left[D_{s_i} D_{s_i}^{T}\right] \\
          & w_{i}^{(k+1)} \leftarrow w_{i}^{(k)}-\gamma\left(\left(A_{i-1}^{(k)}+\lambda I\right)^{-1}
            \otimes\left(G_{i}^{(k)}+\lambda I\right)^{-1}\right) \nabla_{w_{i}} J^{(k)}
        \end{array}

    :math:`a_{i-1}` represents the input of the :math:`i`-th layer, which is the activations of the previous layer.
    :math:`D_{s_i}` represents the derivative of the loss function of the output of the :math:`i`-th layer.
    :math:`I` represents the identity matrix.
    :math:`\lambda` represents :math:`damping`, :math:`\nabla_{w_{i}} J^{(k)}` represents the gradients of
    the :math:`i`-th layer.
    :math:`\otimes` represents Kronecker product, :math:`\gamma` represents 'learning rate'.

    Note:
        When a parameter group is separated, 'weight_decay' of each group is applied to the corresponding parameter.
        'weight_decay' in the optimizer is applied to arguments that do not have 'beta' or 'gamma' in their name
        when the argument group is not separated.
        When separating parameter groups, set grad_centralization to True if you want to concentrate gradients,
        but concentration gradients can only be applied to parameters of the convolution layer.
        If the parameter for the unconvolutional layer is set to True, an error will be reported.
        To improve the performance of parameter groups, you can customize the order of parameters.

    Args:
        net (Cell): The training network.

        learning_rate (Tensor): A value for the learning rate.

        damping (Tensor): A value for the damping.

        momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at least 0.0.

        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0.
            Default: ``0.0`` .

        loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
            default value. Default: ``1.0`` .

        batch_size (int): The size of a batch. Default: ``32`` .

        use_nesterov (bool): Enable Nesterov momentum. It is only passed to the optimizer created for
            non-Ascend device targets. Default: ``False`` .

        decay_filter (function): A function to determine which layers the weight decay applied to. And it
            only works when the weight_decay > 0. Default: lambda x: x.name not in []

        split_indices (list): Set allreduce fusion strategy by A/G layer indices . Only works when distributed
            computing. ResNet50 as an example, there are 54 layers of A/G respectively, when split_indices is set
            to [26, 53], it means A/G is divided into two groups to allreduce,  one is 0~26 layer, and the other
            is 27~53. Default: ``None`` .

        enable_clip_grad (bool): Whether to clip the gradients. Default: ``False`` .

        frequency (int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
            (N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps,
            and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: ``100`` .

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not Tensor.
        TypeError: If `loss_scale` or `momentum` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_nesterov` is not a bool.
        TypeError: If `frequency` is not int.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` or `momentum` is less than 0.
        ValueError: If `frequency` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> from mindspore import Tensor
        >>> import mindspore.common.dtype as mstype
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
        >>> optim = nn.thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> loss_scale = ms.FixedLossScaleManager(128, drop_overflow_update=False)
        >>> model = ms.Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
        ...               amp_level="O2", keep_batchnorm_fp32=False)
        >>> model = ms.ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
        ...                                                 loss_scale_manager=loss_scale, metrics={'acc'},
        ...                                                 amp_level="O2", keep_batchnorm_fp32=False)

    """
    # Global side effect: raises the maximum call depth of the context before building the optimizer.
    context.set_context(max_call_depth=10000)
    # Converts `net` in place so that it uses the THOR layer variants.
    ConvertNetUtils().convert_to_thor_net(net)
    if context.get_context("device_target") == "Ascend":
        # ThorAscend takes no `use_nesterov` argument, so it is not forwarded on this path.
        return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
                          split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
    # Every other device target gets the GPU implementation.
    return ThorGpu(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size,
                   use_nesterov, decay_filter, split_indices=split_indices, enable_clip_grad=enable_clip_grad,
                   frequency=frequency)


class ThorGpu(Optimizer):
    """
    ThorGpu

    THOR optimizer implementation returned by `thor` for every device target other than Ascend.
    It collects the 'matrix_a' / 'matrix_g' / normalizer parameters from the (already converted) network,
    preconditions the gradients of Conv, FC, Embedding and LayerNorm parameters in `construct`, and then
    applies weight decay, optional gradient clipping and a momentum update.
    """

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None,
                 enable_clip_grad=False, frequency=100):
        """Build the optimizer state for `net`; see `thor` for the meaning of the arguments."""
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self._parameters
        self.use_nesterov = Validator.check_bool(use_nesterov)
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
        self.net = net
        # Parameters of the network are selected by substring match on their names.
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        self.batch_size = Tensor(batch_size, mstype.float32)
        # Reciprocal of the squared loss scale and the squared batch size; both are multiplied into
        # matrix_g in _get_ainv_ginv_list.
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.damping = damping
        self._define_gpu_operator()
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        # Selects the branch of construct() that recomputes the matrices from the covariance parameters.
        self.thor = True
        # The tuples below are filled by _process_matrix_init_and_weight_idx_map.
        self.matrix_a = ()
        self.matrix_g = ()
        self.matrix_a_shape = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_fim_idx_map = ()
        self.weight_conv_idx_map = ()
        self.weight_layertype_idx_map = ()
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
        self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        self._define_gpu_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _define_gpu_operator(self):
        """define gpu operator"""
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.assign = P.Assign()
        self.mul = P.Mul()
        self.gather = P.Gather()
        self.one = Tensor(1, mstype.int32)
        # Default scaling factor for matrix_a / matrix_g; replaced per conv layer in _get_ainv_ginv_list.
        self.feature_map = Tensor(1.0, mstype.float32)
        self.axis = 0
        # Index into self.damping; incremented by one on every construct() call.
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.eye = P.Eye()
        # Block size shared by caculate_matmul_shape, CholeskyTrsm and UpdateThorGradient.
        self.split_dim = 128
        self.embedding_cholesky = P.CholeskyTrsm()
        self.cholesky = P.CholeskyTrsm(split_dim=self.split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.inv = P.Reciprocal()
        self.square = P.Square()
        self.expand = P.ExpandDims()

    def _define_gpu_reducer(self, split_indices):
        """define gpu reducer"""
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            # Without explicit indices, all A/G layers form a single allreduce fusion group.
            if not split_indices:
                self.split_indices = [len(self.matrix_a_cov) - 1]
            else:
                self.split_indices = split_indices
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            # NOTE(review): the G reducer is also built from matrix_a_cov, not matrix_g_cov - confirm intended.
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for GPU, process matrix init shape, and get weight idx map"""
        layer_type_map = get_net_layertype_mask(net)
        # Index into layer_type_map; advanced once all parameters of the current layer were visited.
        layer_counter = 0
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type in [Conv, FC] and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_a_dim = in_channels
                if layer_type == Conv:
                    # For conv weights the A dimension also covers weight dims 2 and 3.
                    matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, self.split_dim)
                matrix_a_inv = Parameter(np.zeros(matrix_a_shape).astype(np.float32),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(np.zeros(matrix_g_shape).astype(np.float32),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + (matrix_a_shape,)
            elif layer_type == Embedding:
                # Embedding layers keep a 1-D matrix_a and a square matrix_g.
                vocab_size = weight_shape[0]
                embedding_size = weight_shape[1]
                matrix_a_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                         name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_g_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                         name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_a = self.matrix_a + (matrix_a_inv,)
                self.matrix_g = self.matrix_g + (matrix_g_inv,)
                self.matrix_a_shape = self.matrix_a_shape + ((vocab_size,),)

            # Every parameter gets one entry in each of the three index maps; -1 marks "not applicable".
            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.thor_layer_count = self.thor_layer_count + 1
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                if layer_type == LayerNorm:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
                else:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _get_ainv_ginv_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce):
        """get matrixA inverse list and matrix G inverse list"""
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            if layer_type in [Conv, FC, Embedding]:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                # Tie the covariance reads to the gradient so they are ordered after it in the graph.
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                damping_a = damping_step
                damping_g = damping_step
                feature_map = self.feature_map
                if layer_type == Conv:
                    # Conv layers scale the damping and the final matrices by their normalizers.
                    a_normalizer = self.a_normalizer[conv_layer_count]
                    g_normalizer = self.g_normalizer[conv_layer_count]
                    a_normalizer = F.depend(a_normalizer, g)
                    g_normalizer = F.depend(g_normalizer, g)
                    damping_a = self.mul(damping_step, 1.0 / a_normalizer)
                    damping_g = self.mul(damping_step, 1.0 / g_normalizer)
                    feature_map = self.sqrt(1.0 / a_normalizer)
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                damping_a = self.sqrt(damping_a)
                damping_g = self.sqrt(damping_g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[1], mstype.float32)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                if layer_type == Embedding:
                    # matrix_a is handled element-wise here: ones instead of an identity matrix and a
                    # Reciprocal instead of the Cholesky-based path.
                    a_eye = P.OnesLike()(matrix_a)
                    matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.inv(matrix_a)
                    matrix_g = self.embedding_cholesky(matrix_g)
                    matrix_g = self.matmul(matrix_g, matrix_g)
                else:
                    # CholeskyTrsm followed by multiplying the result with itself (transpose_a=True);
                    # presumably this yields the damped inverse - TODO confirm against the operator docs.
                    matrix_a = matrix_a + damping_a * a_eye
                    matrix_a = self.cholesky(matrix_a)
                    matrix_a = self.vector_matmul(matrix_a, matrix_a)
                    matrix_a = P.BroadcastTo(self.matrix_a_shape[thor_layer_count])(matrix_a)
                    matrix_g = self.cholesky(matrix_g)
                    matrix_g = self.vector_matmul(matrix_g, matrix_g)
                matrix_a = self.mul(matrix_a, feature_map)
                matrix_g = self.mul(matrix_g, feature_map)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g,)
        return matrix_a_allreduce, matrix_g_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """process layernorm"""
        # Element-wise preconditioning: gradient / (gradient^2 / batch_size + sqrt(damping_step)).
        damping = self.sqrt(damping_step)
        normalizer = self.batch_size
        normalizer = self.cast(normalizer, mstype.float32)
        fim_cov = self.square(gradient)
        fim_cov = self.mul(fim_cov, 1.0 / normalizer)
        fim_cov = fim_cov + damping
        fim_inv = self.inv(fim_cov)
        gradient = self.mul(fim_inv, gradient)
        return gradient

    def _reshape_gradient(self, conv_layer_count, g, g_shape):
        """reshape gradient"""
        # Only conv gradients (conv index != -1) are reshaped back to g_shape.
        if conv_layer_count != -1:
            g = self.reshape(g, g_shape)
        return g

    def construct(self, gradients):
        """
        Precondition `gradients` with the THOR matrices and apply the momentum update.

        Returns the result of mapping the momentum optimizer over all (gradient, parameter, moment) triples.
        """
        params = self.params
        moments = self.moments
        gradients = self.flatten_gradients(gradients)
        gradients = self.scale_grad(gradients)
        # Damping value for the current step, selected from self.damping by cov_step.
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        new_grads = ()
        if self.thor:
            # Recompute the matrices from the covariance parameters and store them in
            # self.matrix_a / self.matrix_g while updating the gradients.
            matrix_ainv_list = ()
            matrix_ginv_list = ()
            matrix_a_allreduce, matrix_g_allreduce = self._get_ainv_ginv_list(gradients, damping_step,
                                                                              matrix_ainv_list, matrix_ginv_list)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)

            for i in range(len(self.params)):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    # Flatten to 2-D (first dim kept) before applying UpdateThorGradient.
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = matrix_a_allreduce[thor_layer_count]
                    matrix_g = matrix_g_allreduce[thor_layer_count]
                    self.assign(self.matrix_a[thor_layer_count], matrix_a)
                    self.assign(self.matrix_g[thor_layer_count], matrix_g)
                    # Scale by matrix_a (expanded on axis 1), then multiply by matrix_g.
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        else:
            # Reuse the matrices already stored in self.matrix_a / self.matrix_g.
            for j in range(len(self.params)):
                g = gradients[j]
                thor_layer_count = self.weight_fim_idx_map[j]
                conv_layer_count = self.weight_conv_idx_map[j]
                layer_type = self.weight_layertype_idx_map[j]
                if layer_type in [Conv, FC]:
                    g_shape = self.shape(g)
                    g = self.reshape(g, (g_shape[0], -1))
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    g = self.update_gradient(matrix_g, g, matrix_a)
                    g = self._reshape_gradient(conv_layer_count, g, g_shape)
                elif layer_type == Embedding:
                    matrix_a = self.matrix_a[thor_layer_count]
                    matrix_g = self.matrix_g[thor_layer_count]
                    # NOTE(review): g already equals gradients[j] at this point; this re-read looks redundant.
                    g = gradients[j]
                    temp_a = self.expand(matrix_a, 1)
                    g = self.mul(temp_a, g)
                    g = self.matmul(g, matrix_g)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        gradients = new_grads

        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success


class ThorAscend(Optimizer):
    """ThorAscend"""

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100):
        """Set up the optimizer state for the Ascend variant of THOR.

        Args:
            net: Network to optimize. Its parameters with `requires_grad` set are the ones
                updated; it is also searched for parameters whose names contain 'matrix_a',
                'matrix_g', 'a_normalizer' and 'g_normalizer'.
            learning_rate: Forwarded to `Optimizer.__init__` and to `_check_param`.
            damping: Stored as `self.damping`; `construct` gathers from it using `cov_step`.
            momentum: Passed to `_check_param`, then wrapped in a float32 Parameter.
            weight_decay: Forwarded to `Optimizer.__init__` and stored on the instance.
            loss_scale: Forwarded to `Optimizer.__init__`; afterwards `self.loss_scale` is
                replaced by a tensor holding 1 / loss_scale**2.
            batch_size: Stored as a float32 tensor; its square is stored as `batch_size_scale`.
            decay_filter: Predicate evaluated on every optimized parameter to build `decay_flags`.
            split_indices: Forwarded to `_define_ascend_reducer`.
            enable_clip_grad: Stored on the instance.
            frequency: Passed to `_check_param` and stored; returned by `get_frequency`.
        """
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self._parameters
        # One zero-initialised moment buffer per optimized parameter.
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.net = net
        # Statistics are picked out of the network purely by substring match on the parameter name.
        self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
        self._define_ascend_operator()
        # Width that _process_conv_matmul_device_pad pads the conv in_channels axis up to.
        self.c0 = 16
        self.device_shape_pad_flag = ()
        # Block size that _get_pad_dim rounds matrix dimensions up to.
        self.diag_block_dim = 128
        # The tuples and counters below start empty and are filled by
        # _process_matrix_init_and_weight_idx_map (and the helpers it calls).
        self.matrix_a = ()
        self.matrix_g = ()
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_conv_idx_map = ()
        self.weight_fim_idx_map = ()
        self.weight_layertype_idx_map = ()
        self.a_split_pad_dim_map = ()
        self.g_split_pad_dim_map = ()
        self.conv_matmul_support_map = ()
        # Leading dimensions for which _process_batch_matmul uses the custom batch matmul op.
        self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36]
        # Leading dimensions for which _get_abs_max uses the fused abs-max op.
        self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32]
        self._process_matrix_init_and_weight_idx_map(self.net)
        self.matrix_a = ParameterTuple(self.matrix_a)
        self.matrix_g = ParameterTuple(self.matrix_g)
        # One float32 parameter of shape [1], initialised to 1, per entry of matrix_a.
        self.matrix_max_inv = ()
        for i in range(len(self.matrix_a)):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name='%s%s' % ("matrix_max", str(i)),
                          requires_grad=False),)
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        # construct branches on this flag.
        self.thor = True
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
        self.damping = damping
        self.batch_size = Tensor(batch_size, mstype.float32)
        # Both scale tensors below are multiplied into matrix_g before it is inverted.
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.enable_clip_grad = enable_clip_grad
        self.frequency = frequency
        # Must run after the index maps are built: it reads conv_layer_count.
        self._define_ascend_reducer(split_indices)

    def get_frequency(self):
        """get thor frequency"""
        return self.frequency

    def _get_pad_dim(self, matrix_dim):
        """get diag split pad dim """
        split_pad_dim = 0
        if matrix_dim == 64:
            return split_pad_dim
        res = matrix_dim % self.diag_block_dim
        if res != 0:
            split_pad_dim = self.diag_block_dim - res
        return split_pad_dim

    def _define_ascend_operator(self):
        """Instantiate the primitives, constants and step counter used by the Ascend THOR graph."""
        # Custom cube matmul kernels applied to conv and dense gradients.
        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()

        # Custom kernels used while inverting the covariance matrices.
        self.cholesky = P.CusCholeskyTrsm()
        self.vector_matmul = P.CusBatchMatMul()
        self.fused_abs_max2 = P.CusFusedAbsMax1()
        self.matrix_combine = P.CusMatrixCombine()

        # Generic matrix products.
        self.matmul = P.MatMul()
        self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True)

        # Shape, layout and data-movement primitives.
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.expand = P.ExpandDims()
        self.slice = P.Slice()
        self.concat = P.Concat(0)
        self.gather = P.Gather()
        self.eye = P.Eye()
        self.cast = P.Cast()
        self.assign = P.Assign()

        # Element-wise math and reductions.
        self.mul = P.Mul()
        self.log = P.Log()
        self.exp = P.Exp()
        self.sqrt = P.Sqrt()
        self.square = P.Square()
        self.inv = P.Inv()
        self.reduce_sum = P.ReduceSum(keep_dims=False)

        # Constants and the step counter used to index the damping schedule.
        self.axis = 0
        self.one = Tensor(1, mstype.int32)
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)

    def _define_ascend_reducer(self, split_indices):
        """Create the all-reduce helpers needed when training is not stand-alone.

        Args:
            split_indices: Fusion split indices; when falsy, a single split at
                `len(self.matrix_a_cov) - 1` is used instead.
        """
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if not self.is_distributed:
            return

        mean = _get_gradients_mean()
        degree = _get_device_num()
        self.split_indices = split_indices if split_indices else [len(self.matrix_a_cov) - 1]

        # The abs-max reducers only exist for networks that contain conv layers.
        if self.conv_layer_count > 0:
            for group in ("hccl_world_groupsum2", "hccl_world_groupsum4"):
                auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, group)
            self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
            self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)

        for group in ("hccl_world_groupsum6", "hccl_world_groupsum8"):
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, group)
        # NOTE(review): the G-side reducers are also built from matrix_a_cov; presumably only
        # the per-layer parameter list matters here - confirm.
        self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
        self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _get_weight_idx_map(self, layer_type, idx, weight_shape):
        """Append the bookkeeping entries for parameter `idx` to the per-parameter index maps.

        Args:
            layer_type: Layer-type code of the layer owning the parameter.
            idx: Position of the parameter in `self.params`.
            weight_shape: Shape of the parameter.
        """
        is_bias = "bias" in self.params[idx].name.lower()
        if is_bias or layer_type not in [Conv, FC, Embedding]:
            # Parameter without THOR matrices: mark it with -1 and keep only the
            # LayerNorm / Other distinction.
            self.weight_fim_idx_map += (-1,)
            self.weight_conv_idx_map += (-1,)
            recorded_type = LayerNorm if layer_type == LayerNorm else Other
            self.weight_layertype_idx_map += (recorded_type,)
            return

        self.weight_fim_idx_map += (self.thor_layer_count,)
        self.weight_layertype_idx_map += (layer_type,)

        if layer_type == Embedding:
            a_pad_dim, g_pad_dim = 0, 0
        else:
            matrix_a_dim = weight_shape[1]
            if layer_type == Conv:
                # For conv weights the A dimension also spans the kernel window.
                matrix_a_dim = matrix_a_dim * weight_shape[2] * weight_shape[3]
            a_pad_dim = self._get_pad_dim(matrix_a_dim)
            g_pad_dim = self._get_pad_dim(weight_shape[0])
        self.a_split_pad_dim_map += (a_pad_dim,)
        self.g_split_pad_dim_map += (g_pad_dim,)
        self.thor_layer_count += 1

        if layer_type == Conv:
            self.weight_conv_idx_map += (self.conv_layer_count,)
            self.conv_layer_count += 1
        else:
            self.weight_conv_idx_map += (-1,)

    def _get_fc_matrix(self, weight_shape):
        """Create the initial matrix_a / matrix_g parameters for a dense layer.

        Nothing is created unless at least one conv layer has been registered so far.

        Args:
            weight_shape: Shape of the dense weight, read as (out_channels, in_channels).
        """
        n_out = weight_shape[0]
        n_in = weight_shape[1]
        if self.conv_layer_count <= 0:
            return

        layer_tag = str(self.thor_layer_count)
        if n_out == 1001:
            # Fixed 16x16-blocked layouts, zero-initialised.
            a_init = np.zeros([128, 128, 16, 16]).astype(np.float16)
            g_init = np.zeros([63, 63, 16, 16]).astype(np.float16)
        else:
            a_init = np.eye(n_in).astype(np.float16)
            g_init = np.eye(n_out).astype(np.float16)

        fc_matrix_a = Parameter(Tensor(a_init), name='matrix_a_inv_' + layer_tag, requires_grad=False)
        fc_matrix_g = Parameter(Tensor(g_init), name="matrix_g_inv_" + layer_tag, requires_grad=False)
        self.matrix_a += (fc_matrix_a,)
        self.matrix_g += (fc_matrix_g,)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """Create the initial inverse matrices and fill the per-parameter index maps.

        Args:
            net: Network from which the layer-type mask is derived.
        """
        layer_type_map = get_net_layertype_mask(net)
        layer_counter = 0
        for idx, weight in enumerate(self.params):
            layer_type = layer_type_map[layer_counter]
            weight_shape = self.shape(weight)
            param_name = weight.name.lower()
            is_weight = "bias" not in param_name

            if is_weight and layer_type == Conv:
                out_channels, in_channels = weight_shape[0], weight_shape[1]
                matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_g_dim = out_channels
                a_dev_shape, a_dev_dim = caculate_device_shape(matrix_a_dim, in_channels, True)
                g_dev_shape, g_dev_dim = caculate_device_shape(matrix_g_dim, in_channels, False)
                a_name = 'matrix_a_inv_' + str(self.thor_layer_count)
                g_name = "matrix_g_inv_" + str(self.thor_layer_count)

                if is_conv_matmul_support_shape(a_dev_shape, g_dev_shape):
                    # Supported by the cube kernels: store identities in device layout.
                    a_init = np.reshape(np.identity(a_dev_dim).astype(np.float16), a_dev_shape)
                    g_init = np.reshape(np.identity(g_dev_dim).astype(np.float16), g_dev_shape)
                    support_flag = 1
                else:
                    a_init = np.eye(matrix_a_dim).astype(np.float16)
                    g_init = np.eye(matrix_g_dim).astype(np.float16)
                    support_flag = 0

                matrix_a_inv = Parameter(Tensor(a_init), name=a_name, requires_grad=False)
                matrix_g_inv = Parameter(Tensor(g_init), name=g_name, requires_grad=False)
                self.conv_matmul_support_map += (support_flag,)
                self.matrix_a += (matrix_a_inv,)
                self.matrix_g += (matrix_g_inv,)
                # Flag layers whose device-side A dimension differs from the logical one.
                self.device_shape_pad_flag += (matrix_a_dim != a_dev_dim,)
            elif is_weight and layer_type == FC:
                self._get_fc_matrix(weight_shape)

            self._get_weight_idx_map(layer_type, idx, weight_shape)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in param_name:
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _process_batch_matmul(self, input_matrix):
        """process batch matmul"""
        input_matrix_shape = self.shape(input_matrix)
        if input_matrix_shape[0] in self.batch_matmul_support_list:
            input_matrix = self.vector_matmul(input_matrix, input_matrix)
        else:
            input_matrix = self.tbe_batch_matmul(input_matrix, input_matrix)
        return input_matrix

    def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0):
        """process cholesky pad"""
        if pad_dim > 0:
            matrix_sup = self.eye(pad_dim, pad_dim, mstype.float32)
            matrix_sup = P.Pad(((0, 0), (matrix_shape0, 0)))(matrix_sup)
            input_matrix = P.Pad(((0, 0), (0, pad_dim)))(input_matrix)
            input_matrix = self.concat((input_matrix, matrix_sup))
        return input_matrix

    def _get_abs_max(self, matrix_inv, origin_dim):
        """Return the absolute maximum of `matrix_inv` together with its combined form.

        Args:
            matrix_inv: Matrix as produced by the batch matmul step.
            origin_dim: Dimension handed to the fused abs-max op as `[origin_dim, origin_dim]`.

        Returns:
            Tuple `(matrix_max, matrix_inv)` where `matrix_inv` has gone through `matrix_combine`.
        """
        leading_dim = self.shape(matrix_inv)[0]
        combined = self.matrix_combine(matrix_inv)
        if leading_dim in self.abs_max_support_list:
            # Fused path: works on the matrix before it is combined.
            partial_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv)
            matrix_max = self.fused_abs_max2(partial_max)
        else:
            # Generic path: abs + reduce-max over the combined matrix.
            matrix_max = P.ReduceMax(keep_dims=False)(P.Abs()(combined))
        return matrix_max, combined

    def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                          matrix_a_max_allreduce, matrix_g_max_allreduce):
        """get fc layer ainv and ginv

        Builds the damped A and G matrices of the dense layer owning parameter `index`,
        runs them through the cholesky and batch matmul ops, and appends the results to
        the running tuples.

        Args:
            index: Position of the dense weight in `self.params`.
            damping_step: Damping value for the current step; its square root is what
                gets added to the diagonals.
            gradients: Tuple of gradients, indexed like `self.params`.
            matrix_a_allreduce: Tuple of A results collected so far.
            matrix_g_allreduce: Tuple of G results collected so far.
            matrix_a_max_allreduce: Tuple of A abs-max values collected so far.
            matrix_g_max_allreduce: Tuple of G abs-max values collected so far.

        Returns:
            The four tuples in the same order. The A and G tuples always gain one entry;
            the two max tuples gain one entry only when `out_channels != 2` and the
            network has at least one conv layer.
        """
        thor_layer_count = self.weight_fim_idx_map[index]
        g = gradients[index]
        matrix_a = self.matrix_a_cov[thor_layer_count]
        matrix_g = self.matrix_g_cov[thor_layer_count]
        # Attach a dependency on the gradient to both covariance reads
        # (presumably so they are read after the backward pass - confirm).
        matrix_a = F.depend(matrix_a, g)
        matrix_g = F.depend(matrix_g, g)
        a_shape = self.shape(matrix_a)
        a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
        g_shape = self.shape(matrix_g)
        g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
        damping = self.sqrt(damping_step)
        # A + sqrt(damping_step) * I, padded, then cholesky op followed by batch matmul with itself.
        matrix_a = matrix_a + damping * a_eye
        a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
        matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
        matrix_a_inv = self.cholesky(matrix_a)
        matrix_a_inv = self._process_batch_matmul(matrix_a_inv)

        weight_shape = self.shape(self.params[index])
        out_channels = weight_shape[0]
        in_channels = weight_shape[1]
        if out_channels == 2:
            # Two-output layer: G is replaced by the identity and neither slicing nor
            # abs-max is applied (NOTE(review): presumably a specific model head - confirm).
            matrix_a_inv = self.matrix_combine(matrix_a_inv)
            matrix_g_inv = g_eye
        else:
            # Rescale G by the stored loss-scale and batch-size factors before damping it.
            matrix_g = self.mul(matrix_g, self.loss_scale)
            matrix_g = self.mul(matrix_g, self.batch_size_scale)
            matrix_g = matrix_g + damping * g_eye
            g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
            matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
            matrix_g_inv = self.cholesky(matrix_g)
            matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
            if self.conv_layer_count > 0:
                # Networks with conv layers also track the abs max of both results.
                a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels)
                g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
                a_max = F.depend(a_max, g)
                g_max = F.depend(g_max, g)
                matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
                matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
            else:
                matrix_a_inv = self.matrix_combine(matrix_a_inv)
                matrix_g_inv = self.matrix_combine(matrix_g_inv)

            # Cut the padding added before the cholesky op back off.
            if a_pad_dim > 0:
                matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels))
            if g_pad_dim > 0:
                matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
            matrix_a_inv_shape = self.shape(matrix_a_inv)
            matrix_g_combine_shape = self.shape(matrix_g_inv)
            if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
                # 2048 x 1001 layer: rearrange both matrices into 16x16 blocks. G is first
                # padded from 1001 to 1008 so it divides by 16; the resulting block counts
                # (128 and 63) match the zero-initialised shapes created in _get_fc_matrix.
                matrix_a_inv = self.reshape(matrix_a_inv,
                                            (matrix_a_inv_shape[0] // 16, 16,
                                             matrix_a_inv_shape[0] // 16, 16))
                matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
                matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)

                matrix_g_inv_shape = self.shape(matrix_g_inv)
                matrix_g_inv = self.reshape(matrix_g_inv,
                                            (matrix_g_inv_shape[0] // 16, 16,
                                             matrix_g_inv_shape[0] // 16, 16))
                matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

        matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
        matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
        return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

    def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv):
        """process conv matmul device pad"""
        if self.device_shape_pad_flag[conv_layer_count]:
            kernel_hw = weight_shape[2] * weight_shape[3]
            in_channels = weight_shape[1]
            matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
            matrix_a_inv = P.Pad(((0, 0), (0, self.c0 - in_channels), (0, 0),
                                  (0, self.c0 - in_channels)))(matrix_a_inv)
        return matrix_a_inv

    def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                      matrix_a_max_allreduce, matrix_g_max_allreduce):
        """get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list

        Walks over every optimized parameter and, for Conv, FC and Embedding entries,
        appends the processed A and G matrices (and, where computed, their abs-max
        values) to the running tuples. Other layer types contribute nothing.

        Args:
            gradients: Tuple of gradients, indexed like `self.params`.
            damping_step: Damping value for the current step.
            matrix_a_allreduce: Tuple of A results collected so far.
            matrix_g_allreduce: Tuple of G results collected so far.
            matrix_a_max_allreduce: Tuple of A abs-max values collected so far.
            matrix_g_max_allreduce: Tuple of G abs-max values collected so far.

        Returns:
            The four tuples, in the same order, with the new entries appended.
        """
        for i in range(len(self.params)):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            weight_shape = self.shape(self.params[i])
            out_channels = weight_shape[0]
            if layer_type == Conv:
                g = gradients[i]
                matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                # Attach a dependency on the gradient to the statistics read below
                # (presumably so they are read after the backward pass - confirm).
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                a_shape = self.shape(matrix_a)
                a_eye = self.eye(a_shape[0], a_shape[0], mstype.float32)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
                a_normalizer = self.a_normalizer[conv_layer_count]
                g_normalizer = self.g_normalizer[conv_layer_count]
                a_normalizer = F.depend(a_normalizer, g)
                g_normalizer = F.depend(g_normalizer, g)
                # Conv damping is rescaled by batch_size / normalizer before the square root.
                damping_a = self.mul(damping_step, self.batch_size / a_normalizer)
                damping_g = self.mul(damping_step, self.batch_size / g_normalizer)
                damping_a = self.sqrt(damping_a)
                # A side: damp, pad, cholesky op, batch matmul with itself, abs max + combine.
                matrix_a = matrix_a + damping_a * a_eye
                a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
                matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
                matrix_a_inv = self.cholesky(matrix_a)
                matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
                a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim)

                # G side: same pipeline, after rescaling by the loss-scale and batch-size factors.
                damping_g = self.sqrt(damping_g)
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping_g * g_eye
                g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
                matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
                matrix_g_inv = self.cholesky(matrix_g)
                matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
                g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)

                # Cut the padding added before the cholesky op back off.
                if a_pad_dim > 0:
                    matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim))
                if g_pad_dim > 0:
                    matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))

                if matmul_support_flag == 1:
                    # Rearrange into the 4-D layout of the stored matrix_a / matrix_g parameters:
                    # reshape with the stored dims 1 and 2 swapped, then transpose to (2, 0, 1, 3).
                    matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv)
                    matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
                    matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
                                                  matrix_a_inv_shape[1], matrix_a_inv_shape[3])
                    matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape)
                    matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
                    matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count])
                    matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2],
                                                  matrix_g_inv_shape[1], matrix_g_inv_shape[3])
                    matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape)
                    matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))

                a_max = F.depend(a_max, g)
                g_max = F.depend(g_max, g)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
                matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
                matrix_g_max_allreduce = matrix_g_max_allreduce + (g_max,)
            elif layer_type == FC:
                matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                    self._get_fc_ainv_ginv(i, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
                                           matrix_a_max_allreduce, matrix_g_max_allreduce)
            elif layer_type == Embedding:
                g = gradients[i]
                matrix_a = self.matrix_a_cov[thor_layer_count]
                matrix_g = self.matrix_g_cov[thor_layer_count]
                matrix_a = F.depend(matrix_a, g)
                matrix_g = F.depend(matrix_g, g)
                g_shape = self.shape(matrix_g)
                g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
                damping = self.sqrt(damping_step)
                # The A statistic is damped with a ones-like tensor (not an identity matrix)
                # and passed through the Inv op instead of the cholesky pipeline.
                a_eye = P.OnesLike()(matrix_a)
                matrix_a = self.mul(matrix_a, 1.0 / self.batch_size)
                matrix_a = matrix_a + damping * a_eye
                matrix_a_inv = self.inv(matrix_a)
                # G follows the same pipeline as for FC, without padding or abs-max tracking.
                matrix_g = self.mul(matrix_g, self.loss_scale)
                matrix_g = self.mul(matrix_g, self.batch_size_scale)
                matrix_g = matrix_g + damping * g_eye
                matrix_g_inv = self.cholesky(matrix_g)
                matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
                matrix_g_inv = self.matrix_combine(matrix_g_inv)
                matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
        return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

    def _process_layernorm(self, damping_step, gradient):
        """Rescale a LayerNorm gradient element-wise.

        The gradient is multiplied by `Inv(gradient**2 / batch_size + sqrt(damping_step))`.

        Args:
            damping_step: Damping value for the current step.
            gradient: Gradient of the LayerNorm parameter.

        Returns:
            The rescaled gradient.
        """
        batch = self.cast(self.batch_size, mstype.float32)
        sqrt_damping = self.sqrt(damping_step)
        squared_grad = self.mul(self.square(gradient), 1.0 / batch)
        inverse = self.inv(squared_grad + sqrt_damping)
        return self.mul(inverse, gradient)

    def _process_thor_fc(self, thor_layer_count, matrix_a_allreduce, matrix_g_allreduce, g):
        """Store the fresh A/G matrices of a dense layer and apply them to its gradient.

        Args:
            thor_layer_count: Index of the layer in the A/G tuples.
            matrix_a_allreduce: Tuple of A matrices for the current step.
            matrix_g_allreduce: Tuple of G matrices for the current step.
            g: Gradient of the dense weight.

        Returns:
            `G @ g @ A`, computed in float16 and cast back to float32.
        """
        a_fresh = matrix_a_allreduce[thor_layer_count]
        g_fresh = matrix_g_allreduce[thor_layer_count]
        # Persist the new matrices so later steps can reuse them.
        self.assign(self.matrix_a_cov[thor_layer_count], a_fresh)
        self.assign(self.matrix_g_cov[thor_layer_count], g_fresh)
        a_half = self.cast(a_fresh, mstype.float16)
        g_half = self.cast(g_fresh, mstype.float16)
        grad_half = self.cast(g, mstype.float16)
        product = self.matmul(self.matmul(g_half, grad_half), a_half)
        return self.cast(product, mstype.float32)

    def _get_second_gradients_one(self, params_len, gradients, new_grads):
        """get second gradients one"""
        for i in range(params_len):
            g = gradients[i]
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            matrix_a = self.matrix_a[thor_layer_count]
            matrix_g = self.matrix_g[thor_layer_count]
            matrix_max = self.matrix_max_inv[thor_layer_count]
            grad_shape = self.shape(g)
            if layer_type == FC:
                if grad_shape[0] == 1001:
                    g = self.cube_matmul_left_fc(matrix_g, g)
                    g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
                else:
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
            elif layer_type == Conv:
                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                if matmul_support_flag == 1:
                    g = self.cube_matmul_left(matrix_g, g)
                    g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
                else:
                    g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
                    g = self.reshape(g, grad_shape)
            new_grads = new_grads + (g,)
        return new_grads

    def _get_second_gradients(self, new_grads, damping_step, gradients):
        """get second gradients for thor"""
        params_len = len(self.params)
        if self.conv_layer_count > 0:
            new_grads = self._get_second_gradients_one(params_len, gradients, new_grads)
        else:
            for i in range(params_len):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                if layer_type == Embedding:
                    temp_a_ori = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.expand(temp_a_ori, 1)
                    g = self.mul(temp_a, g)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(g, temp_g)
                    g = self.cast(g, mstype.float32)
                elif layer_type == FC:
                    temp_a = self.matrix_a_cov[thor_layer_count]
                    temp_g = self.matrix_g_cov[thor_layer_count]
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                elif layer_type == LayerNorm:
                    g = self._process_layernorm(damping_step, g)
                new_grads = new_grads + (g,)
        return new_grads

    def _get_second_grad_by_matmul(self, index, temp_a, temp_g, g, temp_max):
        """Apply freshly computed A/G matrices to the gradient of an FC or Conv parameter.

        Args:
            index: Position of the parameter in `self.params`.
            temp_a: A matrix for the layer.
            temp_g: G matrix for the layer.
            g: Gradient of the parameter.
            temp_max: Scale factor applied after the matrix products.

        Returns:
            Tuple `(g, temp_max)`. For Conv layers `temp_max` has additionally been
            multiplied by `batch_size / a_normalizer`; other layer types leave both untouched.
        """
        kind = self.weight_layertype_idx_map[index]
        conv_idx = self.weight_conv_idx_map[index]
        grad_shape = self.shape(g)
        if kind == FC:
            if grad_shape[0] == 1001:
                g = self.cube_matmul_left_fc(temp_g, g)
                g = self.cube_matmul_right_fc(g, temp_a, temp_max)
            else:
                # Plain path: G @ grad @ A in float16, then rescale in float32.
                a_half = self.cast(temp_a, mstype.float16)
                g_half = self.cast(temp_g, mstype.float16)
                grad_half = self.cast(g, mstype.float16)
                grad_half = self.matmul(self.matmul(g_half, grad_half), a_half)
                g = self.mul(self.cast(grad_half, mstype.float32), temp_max)
        elif kind == Conv:
            a_normalizer = F.depend(self.a_normalizer[conv_idx], g)
            temp_max = self.mul(temp_max, self.batch_size / a_normalizer)
            if self.conv_matmul_support_map[conv_idx] == 1:
                g = self.cube_matmul_left(temp_g, g)
                g = self.cube_matmul_right_mul(g, temp_a, temp_max)
            else:
                # Flatten the kernel axes, multiply as for FC, then restore the shape.
                flat = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                a_half = self.cast(temp_a, mstype.float16)
                g_half = self.cast(temp_g, mstype.float16)
                flat_half = self.cast(flat, mstype.float16)
                flat_half = self.matmul(self.matmul(g_half, flat_half), a_half)
                scaled = self.mul(self.cast(flat_half, mstype.float32), temp_max)
                g = self.reshape(scaled, grad_shape)
        return g, temp_max

    def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):
        """Precondition the gradient of an Embedding, FC or LayerNorm parameter with fresh matrices.

        Args:
            index: Position of the parameter in `self.params`.
            matrix_a_allreduce: Tuple of A results for the current step.
            matrix_g_allreduce: Tuple of G results for the current step.
            g: Gradient of the parameter.
            damping_step: Damping value for the current step (used for LayerNorm only).

        Returns:
            The processed gradient; gradients of other layer types are returned unchanged.
        """
        fim_idx = self.weight_fim_idx_map[index]
        kind = self.weight_layertype_idx_map[index]
        if kind == Embedding:
            a_fresh = matrix_a_allreduce[fim_idx]
            g_fresh = matrix_g_allreduce[fim_idx]
            # Persist the new statistics before using them.
            self.assign(self.matrix_a_cov[fim_idx], a_fresh)
            self.assign(self.matrix_g_cov[fim_idx], g_fresh)
            g = self.mul(self.expand(a_fresh, 1), g)
            g_half = self.cast(g_fresh, mstype.float16)
            grad_half = self.cast(g, mstype.float16)
            g = self.cast(self.matmul(grad_half, g_half), mstype.float32)
        elif kind == FC:
            g = self._process_thor_fc(fim_idx, matrix_a_allreduce, matrix_g_allreduce, g)
        elif kind == LayerNorm:
            g = self._process_layernorm(damping_step, g)
        return g

    def construct(self, gradients):
        """Run one THOR update step on `gradients` and apply it to the parameters.

        When `self.thor` is set, the factor matrices are obtained from
        `_get_ainv_ginv_amax_gmax_list` (and passed through the reducers when
        `self.is_distributed`), then used to precondition every gradient.
        Otherwise the gradients are preconditioned by `_get_second_gradients`.
        Weight decay (if enabled), gradient clipping and the momentum update follow.

        Args:
            gradients (tuple[Tensor]): Gradients of `self.params`, in the same order.

        Returns:
            The result of mapping `_momentum_opt` over the gradients, parameters and moments.
        """
        params = self.params
        moments = self.moments
        gradients = self.flatten_gradients(gradients)
        gradients = self.scale_grad(gradients)
        # Pick the damping value for the current step.
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        if self.thor:
            matrix_a_allreduce = ()
            matrix_g_allreduce = ()
            matrix_a_max_allreduce = ()
            matrix_g_max_allreduce = ()
            matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce = \
                self._get_ainv_ginv_amax_gmax_list(gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                                   matrix_a_max_allreduce, matrix_g_max_allreduce)
            if self.is_distributed:
                matrix_a_allreduce = self.grad_reducer_a(matrix_a_allreduce)
                matrix_g_allreduce = self.grad_reducer_g(matrix_g_allreduce)
                # The max tuples are only reduced when the network has conv layers.
                if self.conv_layer_count > 0:
                    matrix_a_max_allreduce = self.grad_reducer_amax(matrix_a_max_allreduce)
                    matrix_g_max_allreduce = self.grad_reducer_gmax(matrix_g_max_allreduce)

            new_grads = ()
            if self.conv_layer_count > 0:
                for i in range(len(self.params)):
                    g = gradients[i]
                    thor_layer_count = self.weight_fim_idx_map[i]
                    temp_a = matrix_a_allreduce[thor_layer_count]
                    temp_g = matrix_g_allreduce[thor_layer_count]
                    a_max = matrix_a_max_allreduce[thor_layer_count]
                    g_max = matrix_g_max_allreduce[thor_layer_count]
                    # 1 / a_max, written as exp(-log(a_max)); temp_a is scaled by it.
                    matrix_a_inv_max = self.log(a_max)
                    matrix_a_inv_max = self.mul(matrix_a_inv_max, -1)
                    matrix_a_inv_max = self.exp(matrix_a_inv_max)
                    temp_a = self.mul(temp_a, matrix_a_inv_max)
                    # 1 / g_max, written as exp(-log(g_max)); temp_g is scaled by it.
                    matrix_g_inv_max = self.log(g_max)
                    matrix_g_inv_max = self.mul(matrix_g_inv_max, -1)
                    matrix_g_inv_max = self.exp(matrix_g_inv_max)
                    temp_g = self.mul(temp_g, matrix_g_inv_max)
                    # temp_a was divided by a_max and temp_g by g_max, so the factor
                    # that undoes both scalings after the matmuls is a_max * g_max
                    # (not g_max * g_max).
                    temp_max = self.mul(a_max, g_max)
                    temp_a = self.cast(temp_a, mstype.float16)
                    temp_g = self.cast(temp_g, mstype.float16)
                    g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
                    # Store the scaled float16 factors and the rescaling factor on the optimizer.
                    self.assign(self.matrix_a[thor_layer_count], temp_a)
                    self.assign(self.matrix_g[thor_layer_count], temp_g)
                    self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
                    new_grads = new_grads + (g,)
                gradients = new_grads
            else:
                # No conv layers: dispatch on layer type (Embedding / FC / LayerNorm).
                for i in range(len(self.params)):
                    g = gradients[i]
                    g = self._get_second_grad_by_layertype(i, matrix_a_allreduce, matrix_g_allreduce, g, damping_step)
                    new_grads = new_grads + (g,)
                gradients = new_grads
        else:
            new_grads = ()
            gradients = self._get_second_gradients(new_grads, damping_step, gradients)

        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = clip_gradient(self.enable_clip_grad, gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success
